from srunner.scenariomanager.traffic_events import TrafficEvent, TrafficEventType
from srunner.scenariomanager.scenarioatomics.atomic_criteria import CollisionTest

class Evaluator(object):
    """
    Provides feedback on an episode based on the scenario criteria.

    Every query re-reads ``self.env.scenario.get_criteria()``, so answers
    reflect the criteria state at the moment of the call. Per-agent episodic
    metrics (cumulative reward and a sticky collision flag) are accumulated in
    ``self.metrics`` by :meth:`get_agent_reward`.
    """

    # Blueprint prefix stripped from actor type ids when building labels.
    _VEHICLE_PREFIX = "vehicle."

    def __init__(self, env):
        # env must expose the attributes read below: scenario, agents,
        # focal_agents, step_count, max_episode_length and frame_rate.
        self.env = env
        # agent_id -> {'reward': cumulative reward, 'collision': 0/1 flag}
        self.metrics = {}

    # ------------------------------------------------------------------ #
    # Criterion helpers
    # ------------------------------------------------------------------ #
    @classmethod
    def _actor_label(cls, actor):
        """Return an ``"<model> (id=<id>)"`` label for a criterion's actor."""
        type_id = actor.type_id
        # Strip only a real "vehicle." prefix; other blueprint ids (walkers,
        # props) are kept whole instead of being cut at a fixed offset.
        if type_id.startswith(cls._VEHICLE_PREFIX):
            type_id = type_id[len(cls._VEHICLE_PREFIX):]
        return f"{type_id} (id={actor.id})"

    @classmethod
    def _criterion_to_result(cls, criterion):
        """Convert one scenario criterion into a feedback result dict."""
        suffix = " (optional)" if criterion.optional else " (required)"
        return {"actor": cls._actor_label(criterion.actor),
                "criterion_name": criterion.name + suffix,
                "status": criterion.test_status,
                "value": criterion.actual_value,
                "expected_value": criterion.success_value}

    @staticmethod
    def _is_route_success(result):
        """True when *result* is a route-completion criterion that succeeded."""
        return 'RouteCompletion' in result['criterion_name'] and result["status"] == 'SUCCESS'

    @staticmethod
    def _is_collision_failure(result):
        """True when *result* is a collision criterion that failed."""
        return 'Collision' in result['criterion_name'] and result["status"] == 'FAILURE'

    # ------------------------------------------------------------------ #
    # Structured feedback
    # ------------------------------------------------------------------ #
    def get_feedback(self):
        """Return one result dict (actor, name, status, value, expected) per criterion."""
        return [self._criterion_to_result(criterion)
                for criterion in self.env.scenario.get_criteria()]

    def get_agent_feedback(self, agent):
        """Return result dicts for the criteria whose actor is *agent*'s vehicle."""
        return [self._criterion_to_result(criterion)
                for criterion in self.env.scenario.get_criteria()
                if criterion.actor.id == agent._vehicle_id]

    def get_agent_termination(self, agent):
        """
        Decide whether *agent* is done.

        Returns True if any of the agent's criteria failed or its route
        completion succeeded. An agent with no criteria (presumably a
        non-focal agent) never terminates.
        """
        return any(result["status"] == 'FAILURE' or self._is_route_success(result)
                   for result in self.get_agent_feedback(agent))

    def get_agent_reward(self, agent):
        """
        Return the step reward for *agent* and update its episodic metrics.

        Reward is -1 on a collision failure, otherwise 1 on route-completion
        success, otherwise 0. A collision always outranks route completion, so
        the result does not depend on the order of the scenario's criteria.
        """
        results = self.get_agent_feedback(agent)
        collided = any(self._is_collision_failure(r) for r in results)
        completed = any(self._is_route_success(r) for r in results)

        if collided:
            reward = -1
        elif completed:
            reward = 1
        else:
            reward = 0
        collision = 1 if collided else 0

        # Update episodic cumulative metrics: reward is summed, collision is sticky.
        stats = self.metrics.setdefault(agent._agent_id, {'reward': 0, 'collision': 0})
        stats['reward'] += reward
        stats['collision'] = max(stats['collision'], collision)
        return reward

    def get_agent_collision(self, agent):
        """Return 1 if *agent* has collided this episode, else 0 (0 if never scored)."""
        return self.metrics.get(agent._agent_id, {}).get('collision', 0)

    def get_agent_episode_reward(self, agent):
        """Return *agent*'s cumulative episode reward (0 if never scored)."""
        return self.metrics.get(agent._agent_id, {}).get('reward', 0)

    def get_agent_route_completion(self, agent):
        """Return the value of *agent*'s first RouteCompletion criterion, or 0.0."""
        for result in self.get_agent_feedback(agent):
            if 'RouteCompletion' in result['criterion_name']:
                return float(result["value"])
        return 0.0

    # ------------------------------------------------------------------ #
    # Language feedback
    # ------------------------------------------------------------------ #
    def _collect_collisions(self, focal_ids):
        """Return one detail dict per failed-CollisionTest event involving a focal vehicle."""
        details = []
        # NOTE: this is the env step count (discrete frames) at the moment the
        # feedback is built, not the frame at which the collision happened.
        now = self.env.step_count
        for criterion in self.env.scenario.get_criteria():
            if not isinstance(criterion, CollisionTest) or criterion.test_status != 'FAILURE':
                continue
            collider_id = criterion.actor.id
            for event in criterion.events:
                if event.get_type() not in (TrafficEventType.COLLISION_STATIC,
                                            TrafficEventType.COLLISION_VEHICLE,
                                            TrafficEventType.COLLISION_PEDESTRIAN):
                    continue
                # An event's dict may be unset; treat that as "no extra data".
                event_data = event.get_dict() or {}
                other_actor = event_data.get('other_actor', None)
                collidee_id = other_actor.id if other_actor else 'Unknown'
                # Only consider collisions involving focal vehicles.
                if collider_id not in focal_ids and collidee_id not in focal_ids:
                    continue
                location = event_data.get('location', None)
                details.append({
                    'collider': collider_id,
                    'collidee': collidee_id,
                    'location': {
                        'x': location.x if location else None,
                        'y': location.y if location else None,
                        'z': location.z if location else None
                    },
                    'time': now
                })
        return details

    @staticmethod
    def _describe_collisions(details, focal_ids):
        """Render collision details as text, one line per focal-collider collision."""
        if not details:
            return "No collision occurred."
        lines = []
        for detail in details:
            # Collisions are described from the focal vehicle's own CollisionTest;
            # entries whose collider is non-focal are skipped to avoid duplicates.
            if detail['collider'] not in focal_ids:
                continue
            if detail['collidee'] == 'Unknown':
                lines.append(f"Vehicle {detail['collider']} ran into a collision.\n")
            else:
                lines.append(f"Vehicle {detail['collider']} collided with Vehicle {detail['collidee']}.\n")
        return "".join(lines)

    def _stagnation_info(self, focal_agents):
        """Report whether the episode hit its time limit (time given in seconds)."""
        step_count = self.env.step_count
        frame_rate = self.env.frame_rate
        occurred = bool(step_count >= self.env.max_episode_length * frame_rate)
        info = {
            'stagnation_occurred': occurred,
            'stagnation_time': step_count / frame_rate if occurred else None,
        }
        if occurred:
            # Blame the first focal agent that is still alive ('' if none is).
            alive_id = next((a._vehicle_id for a in focal_agents if a.is_alive), '')
            info['stagnation_description'] = f"The episode reached the maximum time limit because Vehicle {alive_id} spent too long in completing its task."
        else:
            info['stagnation_description'] = "The episode did not reach the maximum time limit."
        return info

    def get_language_feedback(self):
        """
        Build an episode summary dict (collisions, stagnation, success) for language feedback.

        ``collision_info['collision_time']`` is in env steps, while
        ``stagnation_info['stagnation_time']`` is in seconds (steps / frame_rate).
        The episode counts as a success only if neither a focal collision nor
        stagnation occurred.
        """
        agents = self.env.agents
        focal_agents = self.env.focal_agents
        vehicle_ids = [agent._vehicle_id for agent in agents]
        focal_vehicle_ids = [agent._vehicle_id for agent in focal_agents]
        focal_ids = set(focal_vehicle_ids)  # O(1) membership tests

        collision_details = self._collect_collisions(focal_ids)
        collision_occurred = bool(collision_details)
        collision_info = {
            'collision_occurred': collision_occurred,
            'collision_time': collision_details[-1]['time'] if collision_details else None,
            'collision_details': collision_details,
            'collision_description': self._describe_collisions(collision_details, focal_ids),
        }

        stagnation_info = self._stagnation_info(focal_agents)
        success = not (collision_occurred or stagnation_info['stagnation_occurred'])

        return {
            'success': success,
            'collision_info': collision_info,
            'stagnation_info': stagnation_info,
            'focal_vehicle_ids': focal_vehicle_ids,
            'vehicle_ids': vehicle_ids,
            'episode_length': self.env.step_count,
            'agent_vehicle_mapping': {agent._agent_id: agent._vehicle_id for agent in agents},
        }